from lib.algorithms import ALGORITHMS
from .baseline import Baseline
from .baseline_loss import BaselineLoss
from .baseline_metric import BaselineAcc
from lib.sampler import NormalSampler
from lib.sampler import EpisodeSampler


@ALGORITHMS.register('baseline')
def baseline(cfg):
    """Assemble the components of the ``'baseline'`` algorithm.

    The function is registered under the name ``'baseline'`` in the
    ``ALGORITHMS`` registry (presumably so it can be looked up by name
    from config — confirm against the registry implementation).

    Args:
        cfg: Configuration object forwarded unchanged to ``Baseline``,
            ``BaselineLoss`` and ``BaselineAcc``.

    Returns:
        dict: A mapping with these keys:
            - ``'net'``: a ``Baseline`` instance built from ``cfg``.
            - ``'loss'``: a ``BaselineLoss`` instance built from ``cfg``.
            - ``'metric'``: a ``BaselineAcc`` instance built from ``cfg``.
            - ``'train_sampler'`` / ``'validate_sampler'``: the
              ``NormalSampler`` class itself (not an instance).
            - ``'test_sampler'``: the ``EpisodeSampler`` class itself
              (not an instance).
    """
    # Components are constructed in the same order as before: network,
    # then loss, then metric.
    network = Baseline(cfg)
    criterion = BaselineLoss(cfg)
    accuracy = BaselineAcc(cfg)

    # Samplers are handed over as classes. The caller is expected to
    # instantiate them (NOTE(review): confirm the constructor arguments
    # the caller supplies).
    return dict(
        net=network,
        loss=criterion,
        metric=accuracy,
        train_sampler=NormalSampler,
        validate_sampler=NormalSampler,
        test_sampler=EpisodeSampler,
    )
